package logs

import (
	"context"
	"errors"
	"go.uber.org/zap"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"
	gormUtils "gorm.io/gorm/utils"
	"time"
)

//要实现以下方法
//wiki:https://learnku.com/docs/gorm/v2/logger/9761#a1f508
//type Interface interface {
//	LogMode(LogLevel) Interface
//	Info(context.Context, string, ...interface{})
//	Warn(context.Context, string, ...interface{})
//	Error(context.Context, string, ...interface{})
//	Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error)
//}
type MysqlLogger struct {
	EnableNormal bool // 是否开始普通日志
}

func (d MysqlLogger) LogMode(level logger.LogLevel) logger.Interface {
	d.LogMode(level)
	return d
}

func (d MysqlLogger) Info(ctx context.Context, s string, i ...interface{}) {
	zapLogger.Info("sql-info", zap.String("s", s), zap.Reflect("i", i))
}

func (d MysqlLogger) Warn(ctx context.Context, s string, i ...interface{}) {
	zapLogger.Info("warn", zap.String("s", s), zap.Reflect("i", i))
}

func (d MysqlLogger) Error(ctx context.Context, s string, i ...interface{}) {
	zapLogger.Info("error", zap.String("s", s), zap.Reflect("i", i))
}

//
// Trace
//  @Description: 记录mysql的日志
//  @receiver d 自定义 gorm日志
//  @param ctx
//  @param begin
//  @param fc
//  @param err
//
func (d MysqlLogger) Trace(
	ctx context.Context,
	begin time.Time,
	fc func() (sql string, rowsAffected int64), err error) {

	elapsed := time.Since(begin)
	sql, rows := fc()

	zapFields := []zap.Field{
		zap.String("sql", sql),
		zap.Int64("rows", rows),
		zap.Float64("elapsed/ms", float64(elapsed.Nanoseconds())/1e6),
		zap.String("file", gormUtils.FileWithLineNum()),
		zap.Error(err),
	}

	switch {
	case err != nil && (!errors.Is(err, gorm.ErrRecordNotFound)):
		//错误
		zapFields = append(zapFields, zap.Error(err))
		GetLogger().Error("sql-err", zapFields...)
	case elapsed > 200*time.Millisecond:
		//慢sql
		GetLogger().Warn("sql-slow", zapFields...)
	default:
		//普通信息
		if d.EnableNormal {
			GetLogger().Info("sql-info", zapFields...)
		}
	}
}
